Hierarchical Embeddings for Disease-Signature Associations

Learning Psi Parameters Through Attention Mechanisms

Sarah Urbut, MD PhD

2025-10-08

Motivation: Current Psi Learning

Current approach:

- Learn \(\psi \in \mathbb{R}^{K \times D}\) directly
- \(\psi_{k,d}\) = association strength between signature \(k\) and disease \(d\)
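In code, the current approach is just a free parameter matrix; a minimal sketch (sizes taken from the final slide):

import torch
import torch.nn as nn

# One free parameter per (signature, disease) pair, no shared structure
K, D = 20, 350
psi = nn.Parameter(torch.randn(K, D) * 0.01)  # [K, D]: K * D = 7,000 parameters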

Problems:

- No semantic structure between diseases
- Difficult to generalize to new diseases
- No interpretable disease relationships

Simplified Hierarchical Embedding Approach

flowchart LR
    A[Diseases] --> B[Embeddings E_d ∈ ℝ^L]
    C[Signatures] --> D[Embeddings E_k ∈ ℝ^L]
    B --> E[ψ_k,d = E_d^T W_a E_k / √L]
    D --> E

    style A fill:#e1f5fe
    style B fill:#e8f5e8
    style C fill:#e1f5fe
    style D fill:#e8f5e8
    style E fill:#ffebee

Mathematical Formulation: Direct Approach

Step 1: Disease and Signature Embeddings

\[E_d \in \mathbb{R}^L \quad \text{(disease embeddings)}\]
\[E_k \in \mathbb{R}^L \quad \text{(signature embeddings)}\]

Step 2: Direct Psi Computation

\[\psi_{k,d} = \frac{E_d^T W_a E_k}{\sqrt{L}} \in \mathbb{R}\]

where \(W_a \in \mathbb{R}^{L \times L}\) is learnable.
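The full \(\psi\) matrix can be computed in one bilinear contraction; a minimal PyTorch sketch with toy sizes (random tensors stand in for learned parameters):

import math
import torch

D, K, L = 350, 20, 8
E_d = torch.randn(D, L)  # disease embeddings
E_k = torch.randn(K, L)  # signature embeddings
W_a = torch.randn(L, L)  # learnable interaction matrix

# psi[k, d] = E_d[d]^T W_a E_k[k] / sqrt(L), for all pairs at once
psi = torch.einsum('dl,lm,km->kd', E_d, W_a, E_k) / math.sqrt(L)
print(psi.shape)  # torch.Size([20, 350])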

Key Insight

Skip the attention softmax and contextualization steps: use the raw attention scores directly as the \(\psi\) values.
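For contrast, a short sketch of what is being skipped (toy values):

import torch

scores = torch.randn(20, 350)  # raw bilinear scores, [K, D]

# A standard attention head would normalize over signatures...
attn = torch.softmax(scores, dim=0)  # each column sums to 1

# ...but here the raw scores are kept: psi preserves sign and magnitude,
# which softmax would discard by forcing nonnegative weights summing to 1.
psi = scores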

Tensor Dimensions

| Variable | Shape | Description |
|----------|--------|-------------------------------|
| \(E_d\) | [D, L] | Disease embeddings |
| \(E_k\) | [K, L] | Signature embeddings |
| \(W_a\) | [L, L] | Attention interaction matrix |
| \(\psi\) | [K, D] | Final association strengths |

- \(D\) = number of diseases
- \(K\) = number of signatures
- \(L\) = embedding dimension

Visualizing the Embeddings of Signatures and Diseases

Disease Embeddings Visualization

Disease and signature embeddings in PCA space showing clustering structure
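A hedged sketch of how such a figure could be generated (assumes scikit-learn and matplotlib; random stand-ins replace the learned weights, and PCA is fit on diseases with signatures projected into the same space):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Stand-ins for learned embeddings (use model weights in practice)
rng = np.random.default_rng(0)
E_d = rng.normal(size=(350, 8))  # disease embeddings [D, L]
E_k = rng.normal(size=(20, 8))   # signature embeddings [K, L]

pca = PCA(n_components=2).fit(E_d)  # fit on diseases
d2, k2 = pca.transform(E_d), pca.transform(E_k)

plt.scatter(d2[:, 0], d2[:, 1], s=10, label="diseases")
plt.scatter(k2[:, 0], k2[:, 1], marker="x", label="signatures")
plt.xlabel("PC1"); plt.ylabel("PC2"); plt.legend(); plt.show()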

Attention Matrix Structure

Attention matrix \(W_a\) showing block-diagonal structure

Psi Parameters Heatmap

Psi parameters showing disease-signature associations

Direct Embedding Interpretation

What \(E_d\) tells us: Each disease is represented by an \(L\)-dimensional vector capturing its biological characteristics.

What \(E_k\) tells us: Each signature is represented by an \(L\)-dimensional vector capturing pathway characteristics.

What \(W_a\) tells us: Learns which disease features should interact with which signature features.

What \(\psi_{k,d}\) tells us: A direct measure of how strongly disease \(d\) associates with signature \(k\) after the \(W_a\) transformation.
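To read \(\psi\) in practice, a small sketch (a random stand-in replaces the computed [K, D] psi; the disease index is hypothetical):

import torch

psi = torch.randn(20, 350)  # stand-in for computed psi [K, D]
d = 5                       # hypothetical disease index

# Top 3 signatures by absolute association strength for disease d
strength, sig_idx = torch.topk(psi[:, d].abs(), k=3)
for s, i in zip(strength.tolist(), sig_idx.tolist()):
    print(f"signature {i}: |psi| = {s:.3f}")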

Implementation in PyTorch

Key Components

import math
import torch
import torch.nn as nn

# Embeddings (inside the model's __init__; D, K, L are stored on self)
self.disease_embeddings = nn.Embedding(D, L)
self.signature_embeddings = nn.Embedding(K, L)

# Attention interaction matrix
self.attention_matrix = nn.Parameter(torch.randn(L, L))

def compute_psi_from_embeddings(self):
    # Look up the full embedding tables
    E_d = self.disease_embeddings(torch.arange(self.D))    # [D, L]
    E_k = self.signature_embeddings(torch.arange(self.K))  # [K, L]

    # Direct psi computation: (E_d W_a) E_k^T / sqrt(L)
    psi_scores = torch.matmul(
        torch.matmul(E_d, self.attention_matrix),  # [D, L]
        E_k.T                                      # -> [D, K]
    ) / math.sqrt(self.L)  # [D, K]

    return psi_scores.T  # [K, D]
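Assembled into a runnable whole, a minimal sketch (the class name and the direct use of the embedding weight tables are assumptions, not the exact production code):

import math
import torch
import torch.nn as nn

class DiseaseSignatureModel(nn.Module):
    """Hypothetical wrapper around the components above."""

    def __init__(self, D, K, L):
        super().__init__()
        self.L = L
        self.disease_embeddings = nn.Embedding(D, L)
        self.signature_embeddings = nn.Embedding(K, L)
        self.attention_matrix = nn.Parameter(torch.randn(L, L))

    def compute_psi_from_embeddings(self):
        E_d = self.disease_embeddings.weight    # [D, L]
        E_k = self.signature_embeddings.weight  # [K, L]
        scores = E_d @ self.attention_matrix @ E_k.T / math.sqrt(self.L)
        return scores.T  # [K, D]

model = DiseaseSignatureModel(D=350, K=20, L=8)
print(model.compute_psi_from_embeddings().shape)  # torch.Size([20, 350])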

Parameter Efficiency Analysis

Parameter comparison between direct and hierarchical approaches

Advantages Over Direct Psi Learning

Parameter Efficiency:

- Direct: \(O(D \times K)\)
- Hierarchical: \(O(D \times L + K \times L + L^2)\)
- When \(L \ll \min(D, K)\), much more efficient (concrete check below)
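As a concrete check with the sizes from the final slide (\(D = 350\), \(K = 20\), \(L = 8\)):

D, K, L = 350, 20, 8

direct = D * K                       # 7,000 parameters
hierarchical = D * L + K * L + L**2  # 2,800 + 160 + 64 = 3,024
print(direct, hierarchical)          # 7000 3024  (~57% fewer)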

Generalization:

- New diseases without retraining signatures
- Learned embeddings capture semantic relationships
- Attention weights reveal interpretable associations

Regularization:

- Embedding space provides smooth representations
- Attention encourages sparse associations
- Implicit regularization through shared representations

Extension: Medication Effects

Medication-Aware Embeddings:

\[E_d(m) = E_d^{base} + E_d^{med} \cdot m_d\]
\[E_k(m) = E_k^{base} + E_k^{med} \cdot m_k\]

Medication-Aware Attention:

\[A_{d,k}(m) = \text{softmax}_k\left(\frac{E_d(m)^T W_a E_k(m)}{\sqrt{L}}\right)\]

Medication-Aware Psi:

\[\psi_{k,d}(m) = W_\psi^T C_{d,k}(m) + b_\psi + W_{med}^T m_d\]

where \(C_{d,k}(m)\) is the attention-contextualized representation of the pair \((d, k)\) obtained from \(A_{d,k}(m)\), and \(W_{med}\) adds a direct medication effect.
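A hedged sketch of the embedding shift only (the first equation above); the covariate dimension M, the tensor layout, and all names here are assumptions:

import torch
import torch.nn as nn

D, L, M = 350, 8, 12           # M = number of medication covariates (assumed)
E_d_base = nn.Embedding(D, L)  # E_d^base
E_d_med = nn.Parameter(torch.randn(L, M) * 0.01)  # E_d^med: meds -> shift

def medication_aware_disease_embeddings(m_d):
    # m_d: [D, M] medication covariates (assumed layout)
    return E_d_base.weight + m_d @ E_d_med.T  # E_d(m), [D, L]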

Model Validation

Key validation metrics showing model performance

Key Benefits Summary

Semantic Structure:

- Diseases with similar embeddings have similar psi values (probed in the sketch below)
- Learned relationships are interpretable
- Attention weights reveal disease-signature associations
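One way to probe the first point (a random stand-in replaces the trained embedding table):

import torch
import torch.nn.functional as F

E_d = torch.randn(350, 8)  # stand-in for learned disease embeddings [D, L]

# Pairwise cosine similarity between disease embeddings -> [D, D]
sim = F.cosine_similarity(E_d[:, None, :], E_d[None, :, :], dim=-1)
# Diseases with high similarity here should show correlated psi columns.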

Model-Based Approach:

- Maintains a principled probabilistic framework
- Embeddings provide regularization
- Can incorporate external knowledge

Scalability:

- Efficient parameter usage
- Easy to add new diseases
- Handles large-scale datasets

Summary & Next Steps

What We’ve Accomplished:

- Replaced direct \(\psi\) learning with hierarchical embeddings
- Used an attention mechanism to relate diseases to signatures
- Created contextualized representations for each disease-signature pair
- Maintained a model-based approach while adding interpretability

Next Steps:

- Implement in our existing model (clust_gp)
- Experiment with different embedding dimensions
- Add medication effects
- Validate on our dataset

For Your Real Application

With \(D = 350\) and \(K = 20\), an embedding dimension of \(L = 5\) to \(8\) gives a good balance of parameter efficiency and interpretability; the arithmetic below shows the savings and the break-even point.
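A quick arithmetic sketch under the parameter counts above:

D, K = 350, 20
for L in (5, 8, 18, 19):
    print(L, D * L + K * L + L**2)
# 5:  1,875   8: 3,024   18: 6,984   19: 7,391
# Hierarchical beats direct (7,000 parameters) for any L <= 18;
# L = 5 to 8 cuts parameters by roughly 57-73%.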